# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
@file advection.py
RestrictionFilter operator generator.
"""
import numpy as np
from hysop.constants import Implementation
from hysop.methods import Remesh
from hysop.numerics.remesh.remesh import RemeshKernel
from hysop.tools.io_utils import IOParams
from hysop.tools.htypes import check_instance, to_list, first_not_None, InstanceOf
from hysop.tools.numpywrappers import npw
from hysop.tools.decorators import debug
from hysop.tools.numerics import find_common_dtype
from hysop.tools.spectral_utils import SpectralTransformUtils
from hysop.tools.method_utils import PolynomialInterpolationMethod
from hysop.fields.continuous_field import Field, ScalarField
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
from hysop.core.graph.node_generator import ComputationalGraphNodeGenerator
from hysop.core.graph.computational_node_frontend import ComputationalGraphNodeFrontend
from hysop.core.memory.memory_request import MemoryRequest
from hysop.operator.base.spectral_operator import SpectralOperatorBase
class SpatialFilterBase:
    """
    Common base implementation for spatial filters transferring a scalar field
    between two grids of different resolutions.

    Restriction subclasses go from a fine grid to a coarse grid, interpolation
    subclasses from a coarse grid to a fine grid. Input and output fields must
    share dimension, left/right boundary conditions and periodicity.
    """

    def __new__(cls, input_field, output_field, input_topo, output_topo, **kwds):
        # Field/topology arguments are consumed by __init__; the next __new__ in
        # the MRO only receives placeholder input/output field mappings.
        return super().__new__(cls, input_fields=None, output_fields=None, **kwds)

    def __init__(self, input_field, output_field, input_topo, output_topo, **kwds):
        """
        Initialize a SpatialFilterBase.

        Parameters
        ----------
        input_field: ScalarField
            Field to be filtered.
        output_field: ScalarField
            Filtered field.
        input_topo: CartesianTopologyDescriptors
            Topology descriptor of input_field.
        output_topo: CartesianTopologyDescriptors
            Topology descriptor of output_field.
        kwds: dict
            Forwarded to the next base class ``__init__``.
        """
        expected_types = (
            (input_field, ScalarField),
            (output_field, ScalarField),
            (input_topo, CartesianTopologyDescriptors),
            (output_topo, CartesianTopologyDescriptors),
        )
        for value, value_type in expected_types:
            check_instance(value, value_type)

        super().__init__(
            input_fields={input_field: input_topo},
            output_fields={output_field: output_topo},
            **kwds,
        )

        src, dst = input_field, output_field
        # Both fields must be defined on compatible domains.
        assert src.dim == dst.dim
        assert (src.lboundaries == dst.lboundaries).all()
        assert (src.rboundaries == dst.rboundaries).all()
        assert (src.periodicity == dst.periodicity).all()

        self.Fin, self.Fout = src, dst
        self.dim = src.dim
        self.dtype = find_common_dtype(src.dtype, dst.dtype)
        # Filled later: iratio in get_field_requirements, grid_ratio in discretize.
        self.iratio = None
        self.grid_ratio = None

    @debug
    def discretize(self):
        """
        Discretize the operator and fetch the discrete input/output fields.

        Sets self.dFin, self.dFout and self.grid_ratio (self.iratio transposed
        with the input field topology state). Does nothing when the operator
        is already discretized.
        """
        if self.discretized:
            return
        super().discretize()
        discrete_in = self.get_input_discrete_field(self.Fin)
        discrete_out = self.get_output_discrete_field(self.Fout)
        ratio = discrete_in.topology_state.transposed(self.iratio)
        self.dFin, self.dFout, self.grid_ratio = discrete_in, discrete_out, ratio

    @classmethod
    def supports_multiple_field_topologies(cls):
        """Return True: input and output fields may use different topologies."""
        return True

    @classmethod
    def supports_mpi(cls):
        """Return True: this operator supports MPI."""
        return True
class RestrictionFilterBase(SpatialFilterBase):
    """
    Base class for restriction filters (fine grid -> coarse grid).

    Computes self.iratio, the per-axis integer ratio between the output
    (coarse) and input (fine) space steps.
    """

    @debug
    def get_field_requirements(self):
        """
        Resolve field requirements and compute self.iratio.

        Returns
        -------
        The field requirements returned by the base class.

        Raises
        ------
        AssertionError
            If the destination grid is finer than the source grid, or if the
            grid ratio is not an integer on at least one axis.
        """

        def space_step(topo):
            # Topology descriptors expose space_step either directly or
            # through their mesh.
            try:
                return topo.space_step
            except AttributeError:
                return topo.mesh.space_step

        requirements = super().get_field_requirements()
        Fin_topo, _ = requirements.get_input_requirement(self.Fin)
        Fin_dx = space_step(Fin_topo)
        Fout_topo, _ = requirements.get_output_requirement(self.Fout)
        Fout_dx = space_step(Fout_topo)

        ratio = np.asarray(Fout_dx / Fin_dx)
        # Space steps are floating point values, so an exactly integer grid
        # ratio such as 3 may be computed as 2.9999999999999996. Truncating
        # with astype() would turn it into 2 and spuriously fail the integer
        # check: round to nearest and compare with a tolerance scaled on the
        # floating point precision of the ratio.
        rtol = 64 * np.finfo(ratio.dtype).eps

        msg = f"Destination grid is finer than source grid: {ratio}"
        assert (ratio >= 1.0 - rtol).all(), msg
        iratio = np.rint(ratio).astype(npw.int32)
        msg = f"Grid ratio is not an integer on at least one axis: {ratio}"
        assert np.allclose(ratio, iratio, rtol=rtol, atol=0.0), msg
        self.iratio = tuple(iratio.tolist())
        return requirements
class InterpolationFilterBase(SpatialFilterBase):
    """
    Base class for interpolation filters (coarse grid -> fine grid).

    Computes self.iratio, the per-axis integer ratio between the input
    (coarse) and output (fine) space steps.
    """

    @debug
    def get_field_requirements(self):
        """
        Resolve field requirements and compute self.iratio.

        Returns
        -------
        The field requirements returned by the base class.

        Raises
        ------
        AssertionError
            If the source grid is finer than the destination grid, or if the
            grid ratio is not an integer on at least one axis.
        """

        def space_step(topo):
            # Topology descriptors expose space_step either directly or
            # through their mesh.
            try:
                return topo.space_step
            except AttributeError:
                return topo.mesh.space_step

        requirements = super().get_field_requirements()
        Fin_topo, _ = requirements.get_input_requirement(self.Fin)
        Fin_dx = space_step(Fin_topo)
        Fout_topo, _ = requirements.get_output_requirement(self.Fout)
        Fout_dx = space_step(Fout_topo)

        ratio = np.asarray(Fin_dx / Fout_dx)
        # Space steps are floating point values, so an exactly integer grid
        # ratio such as 3 may be computed as 2.9999999999999996. Truncating
        # with astype() would turn it into 2 and spuriously fail the integer
        # check: round to nearest and compare with a tolerance scaled on the
        # floating point precision of the ratio.
        rtol = 64 * np.finfo(ratio.dtype).eps

        msg = f"Source grid is finer than destination grid: {ratio}"
        assert (ratio >= 1.0 - rtol).all(), msg
        iratio = np.rint(ratio).astype(npw.int32)
        msg = f"Grid ratio is not an integer on at least one axis: {ratio}"
        assert np.allclose(ratio, iratio, rtol=rtol, atol=0.0), msg
        self.iratio = tuple(iratio.tolist())
        return requirements
class SpectralRestrictionFilterBase(RestrictionFilterBase, SpectralOperatorBase):
    """
    Base implementation for lowpass spatial filtering: fine grid -> coarse grid
    using the spectral method.

    The input field is forward transformed on the fine grid and the output
    field is backward transformed on the coarse grid. setup() computes the
    slices mapping fine spectrum modes to coarse spectrum modes; the scaling
    coefficient (_compute_scaling_coefficient) must be provided by backend
    subclasses.
    """

    @debug
    def __new__(cls, plot_input_energy=None, plot_output_energy=None, **kwds):
        return super().__new__(cls, **kwds)

    @debug
    def __init__(self, plot_input_energy=None, plot_output_energy=None, **kwds):
        """
        Initialize a SpectralRestrictionFilterBase.

        Parameters
        ----------
        plot_input_energy: IOParams, optional, defaults to None
            Plot input field energy in a custom file.
        plot_output_energy: IOParams, optional, defaults to None
            Plot output field energy in a custom file.
        kwds: dict
            Forwarded to the base class constructors.

        Notes
        -----
        IOParams filename is formatted before being used:
            {fname} is replaced with field name
            {ite} is replaced with simulation iteration id
        If None is passed, no plots are generated.
        """
        check_instance(plot_input_energy, IOParams, allow_none=True)
        check_instance(plot_output_energy, IOParams, allow_none=True)
        super().__init__(**kwds)
        Fin, Fout = self.Fin, self.Fout

        # check that boundary conditions are matching
        msg = (
            "Input field {l}boundaries {} mismatch with output field {l}boundaries {}."
        )
        assert (Fin.lboundaries == Fout.lboundaries).all(), msg.format(
            Fin.lboundaries, Fout.lboundaries, l="l"
        )
        assert (Fin.rboundaries == Fout.rboundaries).all(), msg.format(
            Fin.rboundaries, Fout.rboundaries, l="r"
        )

        # build spectral transforms: forward transform of the input on the
        # fine grid, backward transform of the output on the coarse grid, each
        # in its own transform group (distinct memory tags).
        tg_fine = self.new_transform_group(mem_tag="FINE")
        tg_coarse = self.new_transform_group(mem_tag="COARSE")
        Ft = tg_fine.require_forward_transform(
            Fin, custom_output_buffer="auto", plot_energy=plot_input_energy
        )
        Bt = tg_coarse.require_backward_transform(
            Fout, custom_input_buffer="B0", plot_energy=plot_output_energy
        )
        self.tg_fine = tg_fine
        self.tg_coarse = tg_coarse
        self.Ft = Ft
        self.Bt = Bt

    @debug
    def discretize(self):
        """
        Discretize the operator and check that the fine (input) compute
        resolution is at least the coarse (output) one on every axis.
        """
        if self.discretized:
            return
        super().discretize()
        dFin, dFout = self.dFin, self.dFout
        # The coarse mesh is the output field, the fine mesh the input field.
        msg = (
            "Compute resolution of coarse mesh {}::{} is greater than compute "
            "resolution of fine mesh {}::{}."
        ).format(
            self.Fout.name,
            dFout.compute_resolution,
            self.Fin.name,
            dFin.compute_resolution,
        )
        assert (dFin.compute_resolution >= dFout.compute_resolution).all(), msg

    def setup(self, work):
        """Set up the operator and precompute filter slices and scaling."""
        super().setup(work)
        self.FIN = self.Ft.output_buffer
        self.FOUT = self.Bt.input_buffer
        self.fslices = self._generate_filter_slices()
        self.scaling = self._compute_scaling_coefficient()

    def _generate_filter_slices(self):
        """
        Build the slices selecting the retained modes.

        Returns
        -------
        (src_slices, dst_slices): tuple of two equally long tuples. Each entry
        is a tuple with one slice per spectral axis; src_slices index the fine
        spectrum buffer (self.FIN) and dst_slices the coarse one (self.FOUT).
        Along C2C axes the number of entries doubles (left and right parts of
        the spectrum), other axes keep the first n modes.
        """
        src_slices = [[]]
        dst_slices = [[]]
        transforms = tuple(self.Ft.transforms[i] for i in self.Ft.output_axes)
        for N, n, tr in zip(self.FIN.shape, self.FOUT.shape, transforms):
            assert len(src_slices) == len(dst_slices)
            assert n <= N
            if SpectralTransformUtils.is_C2C(tr):
                # Keep the first (n+1)//2 and the last n//2 entries of the fine
                # spectrum (assuming the usual DFT ordering, these are the
                # lowest positive and negative frequencies).
                lsrc = slice(0, (n + 1) // 2, 1)
                rsrc = slice(N - n // 2, N, 1)
                ldst = slice(0, (n + 1) // 2, 1)
                rdst = slice(n - n // 2, n, 1)
                src_slices = [s + [lsrc] for s in src_slices] + [
                    s + [rsrc] for s in src_slices
                ]
                dst_slices = [s + [ldst] for s in dst_slices] + [
                    s + [rdst] for s in dst_slices
                ]
            else:
                # Non-C2C axes: keep the first n modes.
                src = slice(0, n, 1)
                dst = slice(0, n, 1)
                src_slices = [s + [src] for s in src_slices]
                dst_slices = [s + [dst] for s in dst_slices]
        src_slices = tuple(tuple(s) for s in src_slices)
        dst_slices = tuple(tuple(s) for s in dst_slices)
        return (src_slices, dst_slices)

    def _compute_scaling_coefficient(self):
        # scaling can depend on the fft backend so we bruteforce it
        # in every backend
        msg = "_compute_scaling_coefficient() has not been implemented for operator {}."
        raise NotImplementedError(msg.format(type(self)))
class RemeshRestrictionFilterBase(RestrictionFilterBase):
    """
    Base implementation for lowpass spatial filtering: fine grid -> coarse grid
    using remeshing kernels.
    """

    __default_method = {
        Remesh: Remesh.L2_1,
    }

    __available_methods = {
        Remesh: (InstanceOf(Remesh), InstanceOf(RemeshKernel)),
    }

    @classmethod
    def default_method(cls):
        """Base default methods extended with the default remeshing kernel."""
        dm = super().default_method()
        dm.update(cls.__default_method)
        return dm

    @classmethod
    def available_methods(cls):
        """Base available methods extended with the remeshing kernel choices."""
        am = super().available_methods()
        am.update(cls.__available_methods)
        return am

    @debug
    def handle_method(self, method):
        """Pop the Remesh entry from method and store it as self.remesh_kernel."""
        super().handle_method(method)
        remesh_kernel = method.pop(Remesh)
        # Remesh enum values are converted to the corresponding kernel object.
        if isinstance(remesh_kernel, Remesh):
            remesh_kernel = RemeshKernel.from_enum(remesh_kernel)
        self.remesh_kernel = remesh_kernel

    @classmethod
    def _remesh_ghosts(cls, remesh_kernel):
        """Return the minimum number of ghosts for remeshed scalars."""
        assert remesh_kernel.n >= 1, "Bad remeshing kernel."
        if remesh_kernel.n > 1:
            assert remesh_kernel.n % 2 == 0, "Odd remeshing kernel moments."
        min_ghosts = int(remesh_kernel.n // 2) + 1
        return min_ghosts

    @debug
    def get_field_requirements(self):
        """
        Request iratio * remesh_ghosts - 1 ghosts per axis on the input field.
        """
        requirements = super().get_field_requirements()
        iratio = self.iratio
        remesh_ghosts = self._remesh_ghosts(self.remesh_kernel)
        fine_grid_ghosts = tuple(np.multiply(iratio, remesh_ghosts) - 1)
        Fin_topo, Fin_requirements = requirements.get_input_requirement(self.Fin)
        Fin_requirements.min_ghosts = fine_grid_ghosts
        self.remesh_ghosts = remesh_ghosts
        self.fine_grid_ghosts = fine_grid_ghosts
        return requirements

    def compute_weights(self, iratio, product=True):
        """
        Tabulate the normalized remeshing kernel weights.

        The weights are evaluated on a grid of shape 2 * p * iratio - 1
        (p = kernel.n // 2 + 1) at offsets X = (idx + 1) / iratio - p, then
        normalized so that they sum to one.

        Parameters
        ----------
        iratio: array_like of int
            Per-axis grid ratio, all entries >= 1.
        product: bool, optional
            If True, use the tensor product of 1D kernel values; otherwise
            evaluate the kernel on the radial distance (see note below).

        Sets self.weights (dense array) and self.nz_weights (dict mapping
        multi-indices to nonzero weights).
        """
        iratio_np = np.asarray(iratio)
        assert (iratio_np >= 1).all()
        remesh_kernel = self.remesh_kernel
        p = remesh_kernel.n // 2 + 1
        shape = 2 * p * iratio_np - 1
        weights = npw.zeros(dtype=npw.float64, shape=shape)
        nz_weights = {}
        for idx in npw.ndindex(*shape):
            X = (npw.asarray(idx, dtype=npw.float64) + 1) / iratio_np - p
            if product:
                W = npw.prod(remesh_kernel(X))
            else:
                # this does not seem to work because the sum of the weights is ~1e-5
                R = npw.sqrt(npw.dot(X, X))
                W = remesh_kernel(R)
            weights[idx] = W
            if W != 0:
                nz_weights[idx] = W
        Ws = weights.sum()
        weights = weights / Ws
        nz_weights = {k: v / Ws for (k, v) in nz_weights.items()}
        assert abs(weights.sum() - 1.0) < 1e-8, weights.sum()
        # Use the builtin sum: numpy's sum wraps a dict view into a 0-d object
        # array and returns the view itself instead of summing its values.
        nz_sum = sum(nz_weights.values())
        assert abs(nz_sum - 1.0) < 1e-8, nz_sum
        self.weights = weights
        self.nz_weights = nz_weights

    @debug
    def discretize(self):
        """
        Discretize the operator, compute the kernel weights and bind the input
        view (compute region plus fine grid ghosts) and the output buffer.
        """
        if self.discretized:
            return
        super().discretize()
        dFin, dFout = self.dFin, self.dFout
        grid_ratio = self.grid_ratio
        self.compute_weights(grid_ratio)
        remesh_ghosts = self.remesh_ghosts
        fine_grid_ghosts = np.multiply(grid_ratio, remesh_ghosts) - 1
        fin = dFin.sdata[dFin.local_slices(ghosts=fine_grid_ghosts)]
        fout = dFout.compute_buffers[0]
        self.fin, self.fout = fin, fout
class SubgridRestrictionFilterBase(RestrictionFilterBase):
    """
    Base implementation for lowpass spatial filtering: fine grid -> coarse grid
    by subgrid sampling.

    discretize() views the fine compute buffer with a stride of grid_ratio
    along each axis, which must match the shape of the coarse output buffer.
    """

    @debug
    def discretize(self):
        """Bind the strided input view (self.fin) and the output buffer (self.fout)."""
        if not self.discretized:
            super().discretize()
            ratio = self.grid_ratio
            strided = tuple(slice(None, None, step) for step in ratio)
            subsampled = self.dFin.compute_buffers[0][strided]
            target = self.dFout.compute_buffers[0]
            msg = "Something went wrong during slicing: fin.shape={}, fout.shape={}".format(
                subsampled.shape, target.shape
            )
            assert subsampled.shape == target.shape, msg
            assert npw.prod(ratio) == npw.prod(self.iratio), msg
            self.fin = subsampled
            self.fout = target
class PolynomialInterpolationFilterBase(
    PolynomialInterpolationMethod, InterpolationFilterBase
):
    """
    Base implementation for polynomial interpolation (coarse grid -> fine grid).
    """

    @debug
    def get_field_requirements(self):
        """
        Request the interpolator ghosts plus Fin.periodicity (added per axis)
        as minimum ghosts on the input field.
        """
        field_reqs = super().get_field_requirements()
        stencil_ghosts = self.polynomial_interpolator.ghosts
        # presumably periodicity is a boolean mask, i.e. one extra ghost on
        # periodic axes -- TODO confirm
        needed = np.add(stencil_ghosts, self.Fin.periodicity)
        _topo, input_reqs = field_reqs.get_input_requirement(self.Fin)
        input_reqs.min_ghosts = needed
        self.required_input_ghosts = needed
        return field_reqs

    def discretize(self):
        """
        Discretize the operator, build the subgrid interpolator and bind the
        input (with ghosts) and output data handles and the iteration shape.
        """
        if self.discretized:
            return
        super().discretize()
        src, dst = self.dFin, self.dFout
        ghosts = src.topology_state.transposed(self.required_input_ghosts)
        self.subgrid_interpolator = (
            self.polynomial_interpolator.generate_subgrid_interpolator(
                grid_ratio=self.grid_ratio
            )
        )
        self.fin = src.sdata[src.local_slices(ghosts=ghosts)].handle
        self.fout = dst.sdata[dst.compute_slices].handle
        self.iter_shape = src.compute_resolution + 1 - src.periodicity
class PolynomialRestrictionFilterBase(
    PolynomialInterpolationMethod, RestrictionFilterBase
):
    """
    Base implementation for polynomial restriction (fine grid -> coarse grid)
    using the restrictor generated from the polynomial subgrid interpolator.
    """

    @debug
    def get_field_requirements(self):
        """
        Request iratio * (interpolator ghosts + 1) - 1 minimum ghosts per axis
        on the input field.
        """
        field_reqs = super().get_field_requirements()
        ratio = self.iratio
        widened = np.add(self.polynomial_interpolator.ghosts, 1)
        needed = np.multiply(ratio, widened) - 1
        _topo, input_reqs = field_reqs.get_input_requirement(self.Fin)
        input_reqs.min_ghosts = needed
        self.required_input_ghosts = needed
        return field_reqs

    def discretize(self):
        """
        Discretize the operator, build the subgrid restrictor and bind the
        input (with ghosts) and output data handles and the iteration shape.
        """
        if self.discretized:
            return
        super().discretize()
        src, dst = self.dFin, self.dFout
        ghosts = src.topology_state.transposed(self.required_input_ghosts)
        interpolator = self.polynomial_interpolator.generate_subgrid_interpolator(
            grid_ratio=self.grid_ratio
        )
        restrictor = interpolator.generate_subgrid_restrictor()
        # the restrictor stencil must match the ghosts requested on the input
        assert all(restrictor.ghosts == ghosts)
        self.subgrid_restrictor = restrictor
        self.fin = src.sdata[src.local_slices(ghosts=ghosts)].handle
        self.fout = dst.sdata[dst.compute_slices].handle
        self.iter_shape = dst.compute_resolution